
// This file is part of Module Proxy.

// Module Proxy is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.

// Module Proxy is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with Module Proxy.  If not, see <https://www.gnu.org/licenses/>.


//         Copyright (C) 2021 - 2030  关中麦客  
//         All rights reserved
//
//         mod_socket.rs
//         socket代理模块
//
//         created by 关中麦客 1036038462@qq.com

use bytes::{Bytes, Buf, BufMut, BytesMut};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use hyper::{Body, Request, Response, Error};
use hyper::body;
// use serde_derive::Deserialize;
use std::sync::Mutex;
use lazy_static::lazy_static;
use std::collections::HashMap;

// Global cache of forwarding clusters (global cache).
lazy_static! {
    // Outer key: module name; inner key: forwarding address (format "127.0.0.1:8888").
    // Inner value: time of the last forwarding error on that address, in seconds as
    // returned by super::util::sec_timestamp(); init_cluster seeds every address with 0
    // (no error recorded yet).
    static ref PROXY: Mutex<HashMap<String, HashMap<String, u32>>> = Mutex::new(HashMap::new());
}

pub const ERR_TIMEOUT: u32 = 60;        // forwarding-error cooldown (60 seconds): an address that failed less than this long ago is skipped 

/// 构建转发集群
pub fn init_cluster()
{
    let vec = super::conf::all_module_socket(); //获得所有socket module配置
    for one_config in vec
    {
        let mod_name = one_config.module;      //模块名
        let addr = one_config.forward_pass;    //地址（集群配置）

        let mut cluster_map: HashMap<String, u32> = HashMap::new(); //模块名对应的转发集群队列

        let cluster: Vec<&str> = addr.split(",").collect(); //逗号分割集群地址
        for one_item in cluster
        {
            let addr = one_item.trim().to_string();     //地址（格式 127.0.0.1:8888）
            cluster_map.insert(addr, 0);                //压入转发集群队列
        }
        
        PROXY.lock().unwrap().insert(mod_name, cluster_map);   
    }
}

/// Return `true` if `mod_name` has a cluster entry in the socket forwarding cache.
///
/// A module whose configured address list was empty still counts as existing;
/// forwarding for it then fails at connect time.
pub fn exist(mod_name: &str) -> bool
{
    PROXY.lock().unwrap().contains_key(mod_name)
}

/// Pick a forwarding address for `mod_name` from its cluster.
///
/// An address is healthy when its last recorded error is more than `ERR_TIMEOUT`
/// seconds old. One healthy address is chosen at random via `super::util::rand`.
/// If no address is healthy, the choice falls back to all configured addresses,
/// so traffic is never dropped entirely.
///
/// Returns an empty string when the module is unknown or its cluster is empty;
/// the caller's connect then fails and it answers with a 500.
fn get_addr_from_cluster(mod_name: &str) -> String
{
    let proxy = PROXY.lock().unwrap();
    let addr_map = match proxy.get(mod_name)
    {
        Some(map) => map,
        None => return String::new(), // module not loaded from config
    };

    let now = super::util::sec_timestamp();
    let healthy: Vec<&String> = addr_map
        .iter()
        .filter(|&(_, &last_err_time)| last_err_time + ERR_TIMEOUT < now)
        .map(|(addr, _)| addr)
        .collect();

    // Every address is in its error cooldown: fall back to the whole cluster.
    let candidates: Vec<&String> = if healthy.is_empty()
    {
        addr_map.keys().collect()
    }
    else
    {
        healthy
    };

    if candidates.is_empty()
    {
        return String::new(); // configuration problem: no addresses at all
    }

    // Clamp in case util::rand's upper bound is inclusive. An out-of-range index
    // would panic while PROXY is locked and poison it for every later request.
    let index = (super::util::rand(0, candidates.len() as u32) as usize).min(candidates.len() - 1);
    candidates[index].clone()
}

/// Record that forwarding to `addr` for module `mod_name` just failed.
///
/// Stamps the address with the current `super::util::sec_timestamp()`, which keeps
/// it out of the healthy set for `ERR_TIMEOUT` seconds.
///
/// Only addresses that are still part of the module's cluster are updated.
/// If the cluster was rebuilt (`init_cluster` replaces the whole map) while a
/// request was in flight, a stale address must not be re-inserted.
fn set_status(mod_name: String, addr: String)
{
    if let Some(last_err_time) = PROXY
        .lock()
        .unwrap()
        .get_mut(&mod_name)
        .and_then(|addr_map| addr_map.get_mut(&addr))
    {
        *last_err_time = super::util::sec_timestamp();
    }
}

/// Log every cluster entry at debug level.
///
/// Each address is tagged `OK` when its last error is more than `ERR_TIMEOUT`
/// seconds old (eligible for forwarding), otherwise `REJECT`.
pub fn print_cache()
{
    log::debug!("[module socket cluster]");
    let proxy = PROXY.lock().unwrap();
    for (mod_name, addr_map) in proxy.iter()
    {
        for (addr, last_err_time) in addr_map.iter()
        {
            let healthy = last_err_time + ERR_TIMEOUT < super::util::sec_timestamp();
            let status = if healthy { "OK" } else { "REJECT" };
            log::debug!("{} : {} -- {}", mod_name, addr, status);
        }
    }
}

/// socket代理
pub async fn socket(request: Request<Body>, mod_name: &str) -> Result<Response<Body>, Error>
{
    let addr = get_addr_from_cluster(mod_name); //转发地址

    //body json
    let (_, body) = request.into_parts();
    let json = match body::to_bytes(body).await
    {
        Ok(bytes) => bytes.iter().cloned().collect::<Vec<u8>>().into(),
        Err(_) => Bytes::new(),
    };

    match TcpStream::connect(&addr).await   //创建socket
    {
        Ok(mut socket) => 
        {
            //发送长度行
            let len_msg = format!("{:>10}\r\n", json.len());
            //组装发送消息
            let mut buf = BytesMut::new();
            buf.put(len_msg.as_bytes());
            buf.put(json);
            let msg_bytes = Bytes::from(buf);

            //发送
            if let Err(err) = socket.write_all(&msg_bytes).await
            {
                log::warn!("{} - socket forward error: {}", mod_name, err);
                set_status(mod_name.to_string(), addr);         //转发错误记录到缓存
                print_cache();  //打印缓存
                return Ok(super::response::rsp_500().await);
            }

            //接收
            match read(socket).await
            {
                Some(bytes) => return Ok(super::response::socket_rsp_200(bytes).await),
                None => 
                {
                    log::warn!("{} - socket forward read error", mod_name);
                    set_status(mod_name.to_string(), addr);     //转发错误记录到缓存 
                    print_cache();  //打印缓存
                    return Ok(super::response::rsp_500().await);
                }
            }          
        },
        Err(err) => 
        {
            log::error!("{} - socket forward connect error: {}", &addr, err);
            Ok(super::response::rsp_500().await)
        }
    }   
}

/// socket读
pub async fn read(mut socket: TcpStream) -> Option<Bytes>
{
    let mut len_buf: [u8; 12] = [0; 12];        //12字节buf
    match socket.read(&mut len_buf).await       //socket读缓长度12字节
    {
        Ok(c_size) =>
        {
            //读不到起始的12字节长度
            if c_size != 12   
            {
                return None;
            }

            //解析rsp_json的长度
            let len_line = String::from_utf8(len_buf.to_vec()).unwrap();
            let len_str = len_line.trim();
            let rsp_json_len = len_str.parse::<usize>().unwrap(); //rsp_json的长度

            let mut buffer = BytesMut::new();           //读取rsp_json缓冲区
            let mut read_buf: [u8; 4096] = [0; 4096];   //socket读缓冲

            loop 
            {
                //读取完整
                if buffer.len() >= rsp_json_len
                {
                    let json_bytes = buffer.to_bytes(); 
                    return Some(json_bytes);
                }

                //从socket中读取
                if let Ok(c_size) = socket.read(&mut read_buf).await
                {
                    buffer.put(&read_buf[..c_size]);  //本次读取的数据，合并到buffer
                }
            }
        },
        Err(err) =>
        {
            log::warn!("socket read error: {}", err);
            None
        }
    }
}

//---------------------------------------------------------
// test
//---------------------------------------------------------
